SageMaker Sparkを使ったXGBoostでの手書き数字分類をやってみた
SageMaker Sparkを使用して、Amazon SageMaker上で手書き数字の分類モデルを作成するノートブックをやってみたので、その内容を紹介します。
概要
SageMaker Sparkを使って、Sparkをローカル上に展開し、MNISTのデータを読み込み、SageMaker上でXGBoostの分類モデルの学習とホストを行います。その後、作成したモデルを使ってテストデータの分類を行います。
- SageMaker Spark
- Apache SparkのSageMaker用のオープンソースのライブラリです。Sparkに読み込んだデータを使ったSageMaker上での学習や、SparkのMLlibとSageMakerを連携させること等が出来ます。
- MNISTのデータセット
- 手書き数字画像とラベルデータを含んだデータセットです。データ分析のチュートリアルでよく使われるデータセットの一つです。
- XGBoost
- 勾配ブースティングツリーという理論のオープンソースの実装で、分類や回帰によく使われる機械学習アルゴリズムの一つです。
基本的にはawslabsのノートブックに沿って進めますが、一部変更している箇所があります。
やってみた
ノートブックの作成
SageMakerのノートブックインスタンスを立ち上げて、
SageMaker Examples
↓
Sagemaker Spark
↓
pyspark_mnist_xgboost.ipynb
↓
use
でサンプルからノートブックをコピーして、開きます。
ノートブックインスタンスの作成についてはこちらをご参照ください。
セットアップ
IAMロールの取得とSparkセッションの構築を行います。 今回はローカル上にSparkアプリケーションを実行し、セッションに繋ぎます。 リモートのSparkクラスタと接続する場合はAmazon SageMaker PySparkのGitHubでの解説をご参照ください。
import os from pyspark import SparkContext, SparkConf from pyspark.sql import SparkSession import sagemaker from sagemaker import get_execution_role import sagemaker_pyspark # 実行しているIAMロールを取得する role = get_execution_role() # SageMaker Sparkが依存するjarを取得する jars = sagemaker_pyspark.classpath_jars() classpath = ":".join(sagemaker_pyspark.classpath_jars()) # ローカルで実行するSparkアプリケーションとのセッションを取得する spark = SparkSession.builder.config("spark.driver.extraClassPath", classpath)\ .master("local[*]").getOrCreate()
データの読み込み
MNISTのデータセットをSpark Dataframeに読み込みます。 データセットはLibSVM形式のものがS3上で公開されています。 LibSVM形式についてはこちらのエントリをご覧ください。
学習と推論のデータセットと使用するには2種類のカラムが必要です。一つはダブル型のカラム(デフォルトではlabel
という名前)、二つ目はダブル型のベクトル(デフォルトではfeatures
という名前)です。
import boto3 region = boto3.Session().region_name trainingData = spark.read.format('libsvm')\ .option('numFeatures', '784')\ .option('vectorType', 'dense')\ .load('s3a://sagemaker-sample-data-{}/spark/mnist/train/'.format(region)) testData = spark.read.format('libsvm')\ .option('numFeatures', '784')\ .option('vectorType', 'dense')\ .load('s3a://sagemaker-sample-data-{}/spark/mnist/test/'.format(region)) trainingData.show()
モデルの学習とホスティング
Estimatorを作成し、学習とモデルをホストするエンドポイントの作成を行います。 Estimatorには学習とエンドポイントのパラメータ、学習時のハイパーパラメータを設定します。 XGBoostのハイパーパラメータについてはドキュメントをご参照ください。
import random from sagemaker_pyspark import IAMRole, S3DataPath from sagemaker_pyspark.algorithms import XGBoostSageMakerEstimator from sagemaker_pyspark.S3Resources import S3DataPath # モデルアーティファクトの出力先を指定(ノートブックにはなく、追加した内容です。) output_s3_data = S3DataPath('bucket_name', 'sagemaker/spark-xgb/output') # estimatorの設定 xgboost_estimator = XGBoostSageMakerEstimator( sagemakerRole=IAMRole(role), # 学習時とエンドポイントの作成時に使用するIAMロール trainingInstanceType='ml.m4.xlarge', # 学習に使用するインスタンスタイプ trainingInstanceCount=1, # 学習に使用するインスタンス数 endpointInstanceType='ml.m4.xlarge', # モデルをホストするエンドポイントのインスタンスタイプ endpointInitialInstanceCount=1, # モデルをホストするエンドポイントのインスタンス数 trainingOutputS3DataPath = output_s3_data) # モデルアーティファクトの出力先 # ハイパーパラメータの設定 xgboost_estimator.setEta(0.2) xgboost_estimator.setGamma(4) xgboost_estimator.setMinChildWeight(6) xgboost_estimator.setSilent(0) xgboost_estimator.setObjective("multi:softmax") xgboost_estimator.setNumClasses(10) xgboost_estimator.setNumRound(10) # 学習処理開始 model = xgboost_estimator.fit(trainingData)
推論
テストデータをまとめて分類(推論)します。この処理を行うために内部的には、featureカラムをLibSVM形式に変換して、モデルをホストするエンドポイントに投げます。その後、エンドポイントでモデルが分類した結果をCSV形式で受け取って、DataFrameに格納します。この処理をtransform
がまとめてやってくれます。便利です。
※ ノートブックではtrainDataを読み込んでいますが、testDataの方が適切だと思われるので、testDataを読み込むように変更しています。
transformedData = model.transform(testData) transformedData.show()
入力データに加えてprediction
カラムが追加されています。
分類結果ごとに入力画像を表示してみて、正しく分類できたかを確認します。
from pyspark.sql.types import DoubleType import matplotlib.pyplot as plt import numpy as np # 数字を表示するための補助関数 def show_digit(img, caption='', xlabel='', subplot=None): if subplot==None: _,(subplot)=plt.subplots(1,1) imgr=img.reshape((28,28)) subplot.axes.get_xaxis().set_ticks([]) subplot.axes.get_yaxis().set_ticks([]) plt.title(caption) plt.xlabel(xlabel) subplot.imshow(imgr, cmap='gray') # 入力画像データを取得 images = np.array(transformedData.select("features").cache().take(250)) # 分類結果を取得 clusters = transformedData.select("prediction").cache().take(250) # 分類結果ごとに入力画像を表示する for cluster in range(10): print('\n\n\nCluster {}:'.format(int(cluster))) digits=[ img for l, img in zip(clusters, images) if int(l.prediction) == cluster ] height=((len(digits) - 1) // 5) + 1 width=5 plt.rcParams["figure.figsize"] = (width,height) _, subplots = plt.subplots(height, width) subplots=np.ndarray.flatten(subplots) for subplot, image in zip(subplots, digits): show_digit(image, subplot=subplot) for subplot in subplots[len(digits):]: subplot.axis('off') plt.show()
実際には0~9までの数字に対して結果が表示されますが、ここでは0と9の結果だけ紹介します。 0に分類された入力画像は全て正しく0を表していそうです。
9に分類された入力画像は基本的に正しそうですが、7や8が紛れ込んでいます。
概ね高い精度で分類出来ていそうです。
エンドポイントの削除
余計な費用が掛からないようにエンドポイントを削除します。
from sagemaker_pyspark import SageMakerResourceCleanup resource_cleanup = SageMakerResourceCleanup(model.sagemakerClient) resource_cleanup.deleteResources(model.getCreatedResources())
おわりに
SageMaker Sparkライブラリを使用して、Amazon SageMaker上で手書き数字の分類を行えました。SageMaker Sparkを使うことで、Amazon SageMakerのモデルをSpark上で扱うことができます。大きなデータを扱ったり、Apache SparkのMLlibとAmazon SageMakerを連携させる際には非常に便利そうです。
これからAmazon SageMakerとApache Sparkを試してみたいと思っている方の参考になれば幸いです。 最後までお読み頂き有難うございましたー!